from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from pandas import DataFrame
from scipy import interpolate
from scipy.stats import pearsonr
from scipy.stats import spearmanr

import re

def is_float(str):
    """Return True if the given string looks like a decimal float literal.

    Accepted forms: an optional leading sign, then either digits with an
    optional fractional part ('3', '3.', '3.25') or a bare fraction ('.5'),
    optionally followed by an exponent ('1e-3', '2.5E+4'). Because the
    pattern is anchored with '$' under re.match, a single trailing newline
    is also tolerated. Words such as 'nan' or 'inf' are rejected.
    """
    float_pattern = r'^[-+]?(\d+\.\d*|\.\d+|\d+)([eE][-+]?\d+)?$'
    match = re.match(float_pattern, str)
    return match is not None


# Read each interval (curve) from the CSV file. Intervals are told apart by ordering: each interval is a run of strictly ascending 'freq' values, and a new one starts wherever the frequency stops increasing.
def getInterval(fre_power_filepath:str):
    """Split a spectrum CSV file into its individual frequency intervals.

    The CSV is read with pandas and must contain 'freq' and 'power'
    columns. Rows are grouped into runs of strictly increasing 'freq': a
    new interval starts on every row whose frequency is not greater than
    the previous row's. Intervals with fewer than 3 points are discarded.

    Args:
        fre_power_filepath: Path of the CSV file to read.

    Returns:
        A tuple ``(freq_list, power_list)`` of two parallel lists. Each holds
        one inner list of values per kept interval, in file order.
    """
    spectrum = pd.read_csv(fre_power_filepath)
    # True where the frequency does not increase, which marks the start of a
    # new interval. The first row compares against NaN (False), so group 0.
    spectrum['group'] = (spectrum['freq'].shift(1) >= spectrum['freq']).cumsum()

    freq_list = []
    power_list = []
    for _, group in spectrum.groupby('group'):
        freq_list.append(group['freq'].tolist())
        power_list.append(group['power'].tolist())

    print('At beginning, there are {} curves in {}'.format(len(freq_list), fre_power_filepath))

    # Drop too-short curves. Every curve, including the first, is checked.
    kept = [(freqs, powers) for freqs, powers in zip(freq_list, power_list) if len(freqs) >= 3]
    freq_list = [freqs for freqs, _ in kept]
    power_list = [powers for _, powers in kept]

    return freq_list, power_list

# Return a piecewise-linear function interpolated through the scatter points (extrapolated linearly outside their range).
def getF(freq_list:list, power_list:list):
    """Build a piecewise-linear interpolant through the (freq, power) points.

    Outside the range spanned by ``freq_list`` the returned function
    extrapolates linearly from the nearest segment instead of raising.

    Returns:
        A callable mapping x values (scalar or array) to interpolated power.
    """
    interpolant = interpolate.interp1d(x=freq_list, y=power_list, kind='linear', fill_value="extrapolate")
    return interpolant

# Given two files, resample every interval of each file on the same fixed grid of x values in [0, 0.5] and return the corresponding y values.
def alignPoints(filepath1:str, filepath2:str, x_min:float=0.0, x_max:float=0.5, num_points:int=1000):
    """Resample matching intervals of two spectrum files on one shared x grid.

    Both files are split into intervals with getInterval. Intervals are
    paired by position, and only the first min(n1, n2) pairs are used. Each
    interval is turned into a linear interpolant (getF) and evaluated on the
    same grid.

    Args:
        filepath1: CSV file of the first spectrum.
        filepath2: CSV file of the second spectrum.
        x_min: First grid value (default 0.0).
        x_max: Last grid value, inclusive (default 0.5).
        num_points: Number of evenly spaced grid points (default 1000).

    Returns:
        ``(x, y1listlist, y2listlist)``. ``x`` is the shared grid from
        np.linspace. The two lists hold one array of y values per interval
        pair, and both lists have the same length, which is 0 when either
        file has no usable interval.
    """
    freq_list_list_1, power_list_list_1 = getInterval(filepath1)
    freq_list_list_2, power_list_list_2 = getInterval(filepath2)

    # The grid is identical for every interval, so build it once. This also
    # keeps `x` defined when there are no interval pairs at all.
    x = np.linspace(x_min, x_max, num_points)

    y1listlist, y2listlist = [], []
    # zip pairs intervals by position and stops at the shorter file.
    for freq_list1, power_list1, freq_list2, power_list2 in zip(
            freq_list_list_1, power_list_list_1, freq_list_list_2, power_list_list_2):
        y1listlist.append(getF(freq_list1, power_list1)(x))
        y2listlist.append(getF(freq_list2, power_list2)(x))

    return x, y1listlist, y2listlist



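# Calculate the PSO value for every frequency interval: the area under the pointwise minimum of the two absolute curves divided by the area under their pointwise maximum.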
def getPSO(filepath1:str, filepath2:str):
    """Compute the PSO overlap score for every aligned frequency interval.

    For each interval pair returned by alignPoints, both curves are taken in
    absolute value. The area under their pointwise minimum ("floor") is
    divided by the area under their pointwise maximum ("roof"), with both
    areas integrated by the trapezoidal rule over the shared x grid. The
    ratio lies in [0, 1] and is 1.0 when the absolute curves coincide.

    Returns:
        ``(area_floor_list, area_roof_list, pso_list)``, one entry per
        interval pair. PSO values are rounded to 4 decimals.
        NOTE(review): if a roof area is 0 (both curves all zero) the ratio
        is NaN (numpy 0/0) rather than an exception.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    # Fall back to np.trapz on older NumPy releases that lack trapezoid.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    area_floor_list, area_roof_list, pso_list = [], [], []
    xlist, y1listlist, y2listlist = alignPoints(filepath1, filepath2)

    for y1, y2 in zip(y1listlist, y2listlist):
        abs_y1 = np.abs(y1)
        abs_y2 = np.abs(y2)

        area_floor = trapezoid(np.minimum(abs_y1, abs_y2), xlist)
        area_roof = trapezoid(np.maximum(abs_y1, abs_y2), xlist)

        area_floor_list.append(area_floor)
        area_roof_list.append(area_roof)
        pso_list.append(round(area_floor / area_roof, 4))

    return area_floor_list, area_roof_list, pso_list


def getSpearmanr(filepath1:str, filepath2:str):
    """Spearman rank correlation of the two files, one value per interval.

    The curves are first resampled on a shared grid by alignPoints. The
    p-value returned by scipy's spearmanr is discarded.

    Returns:
        A list with one correlation coefficient per aligned interval pair.
    """
    _, curves1, curves2 = alignPoints(filepath1, filepath2)
    coefficients = []
    for curve1, curve2 in zip(curves1, curves2):
        rho, _p_value = spearmanr(curve1, curve2)
        coefficients.append(rho)
    return coefficients

# Calculate the Pearson correlation for every frequency interval.
def getPearson(filepath1:str, filepath2:str):
    """Pearson correlation of the two files, one value per interval.

    The curves are first resampled on a shared grid by alignPoints. The
    p-value returned by scipy's pearsonr is discarded.

    Returns:
        A list with one correlation coefficient per aligned interval pair.
    """
    aligned = alignPoints(filepath1, filepath2)
    first_curves, second_curves = aligned[1], aligned[2]
    coefficients = []
    for index, curve in enumerate(first_curves):
        r_value, _ = pearsonr(curve, second_curves[index])
        coefficients.append(r_value)
    return coefficients

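# Calculate the Spectral Angle Mapper (SAM) score between two spectra for every frequency interval:
# the angle between the two curves divided by pi, so 0 means identical shape and larger values mean less similar.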
def getSAM(filepath1:str, filepath2:str):
    """Compute the Spectral Angle Mapper (SAM) score for every aligned interval.

    Each pair of resampled curves (from alignPoints) is treated as two
    vectors. The angle between them is divided by pi, so the score lies in
    [0, 1]: 0 for parallel vectors (identical shape), 0.5 for orthogonal and
    1 for opposite ones. Lower therefore means more similar.

    Returns:
        A list with one score per aligned interval pair.
        NOTE(review): a curve that is all zeros has zero norm, so its
        normalisation yields NaN and the score is NaN.
    """
    xlist, y1listlist, y2listlist = alignPoints(filepath1, filepath2)
    sam_list = []

    for y1, y2 in zip(y1listlist, y2listlist):
        # Normalise to unit length without mutating the aligned arrays.
        unit1 = y1 / np.linalg.norm(y1)
        unit2 = y2 / np.linalg.norm(y2)

        # Rounding can push the dot product of (nearly) parallel unit vectors
        # just outside [-1, 1], where arccos returns NaN, so clip it back.
        cos_angle = np.clip(np.dot(unit1, unit2), -1.0, 1.0)

        sam_list.append(np.arccos(cos_angle) / np.pi)

    return sam_list



# The following code calculates PSO and the other metrics across many CSV files.
def cal():
    """Compute and report average similarity metrics across many CSV pairs.

    For every (model_type, data_source, text_length) combination, an
    original spectrum file is compared with a generated spectrum file (names
    built below). The comparison uses getPSO, getPearson, getSAM and
    getSpearmanr. For each combination, the per-interval scores are averaged
    and one tab-separated row is printed:
    ``<model>_<source>_<length>\\tPSO\\tPearson\\tSAM\\tSpearman``.

    After the text lengths of a (model_type, data_source) pair, a summary row
    ``<model>_<source>\\t...`` is printed. It averages over all intervals of
    all its text lengths. Every row is also written to 'Ans2.txt'.
    """
    data_sources = ('news', 'story', 'wiki')
    model_types = ('6.7b', '125m')
    text_length_tuple = (0, 1, 2, 3, 4)
    rows = []

    def mean(values):
        # Plain arithmetic mean; raises ZeroDivisionError on an empty list.
        return sum(values) / len(values)

    def format_row(label, pso, corr, sam, spearman):
        return '\t'.join([label, str(pso), str(corr), str(sam), str(spearman)]) + '\n'

    for model_type in model_types:
        for data_source in data_sources:
            total_pso_list = []
            total_corr_list = []
            total_sam_list = []
            total_spearmanr_list = []
            for text_length in text_length_tuple:
                original_filename = 'webtext.train.model=.' + data_source + '_' + str(text_length) + '.fft.csv'
                generated_filename = 'webtext.train_opt_' + model_type + '_top_50_' + data_source + '.sorted.split.' + str(
                    text_length * 200) + '.fft.csv'
                _, _, pso_list = getPSO(original_filename, generated_filename)
                corr_list = getPearson(original_filename, generated_filename)
                sam_list = getSAM(original_filename, generated_filename)
                spearmanr_list = getSpearmanr(original_filename, generated_filename)

                tmp_str = format_row(model_type + '_' + data_source + '_' + str(text_length),
                                     mean(pso_list), mean(corr_list), mean(sam_list), mean(spearmanr_list))
                rows.append(tmp_str)
                print(tmp_str)

                total_pso_list.extend(pso_list)
                total_corr_list.extend(corr_list)
                total_sam_list.extend(sam_list)
                total_spearmanr_list.extend(spearmanr_list)

            # Summary row: averages pooled over every interval of every text
            # length for this (model_type, data_source) pair.
            tmp_str = format_row(model_type + '_' + data_source,
                                 mean(total_pso_list), mean(total_corr_list),
                                 mean(total_sam_list), mean(total_spearmanr_list))
            rows.append(tmp_str)
            print(tmp_str)

    with open('Ans2.txt', 'w') as f:
        f.write(''.join(rows))


def calSingle():
    """Compare one fixed pair of spectrum files and print their average scores.

    Compares 'contrastive_gold.fft.csv' with 'model_b.fft.csv' and prints a
    single tab-separated line. The line holds the mean PSO, Pearson, SAM and
    Spearman values over all aligned frequency intervals, in that order.
    """
    gold_path, model_path = 'contrastive_gold.fft.csv', 'model_b.fft.csv'

    # Evaluated in order: PSO, Pearson, SAM, Spearman.
    per_interval_scores = (
        getPSO(gold_path, model_path)[2],
        getPearson(gold_path, model_path),
        getSAM(gold_path, model_path),
        getSpearmanr(gold_path, model_path),
    )
    averages = [sum(scores) / len(scores) for scores in per_interval_scores]

    line = '\t'.join(str(average) for average in averages) + '\n'
    print(line)

# Run the single-pair comparison only when executed as a script, so that
# importing this module for its metric functions does not trigger file I/O.
if __name__ == "__main__":
    calSingle()

